{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use Mistral LLM for Text Classification and Entity Recognition"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "toc": true
   },
   "source": [
    "<h1>Table of Content<span class=\"tocSkip\"></span></h1>\n",
    "<div class=\"toc\">\n",
    "   <ul class=\"toc-item\">\n",
    "      <li>\n",
    "         <span><a href=\"#about-mistral-model\" data-toc-modified-id=\"about-mistral-model\"><span class=\"toc-item-num\"></span>Introduction to Mistral model</a></span>\n",
    "      </li>\n",
    "   </ul>\n",
    "   <ul class=\"toc-item\">\n",
    "      <li>\n",
    "         <span><a href=\"#mistral_integration_in_arcgis_learn\" data-toc-modified-id=\"mistral-integration-in-arcgis-learn\"><span class=\"toc-item-num\"></span>Mistral Implementation in <code>arcgis.learn</code></a></span>\n",
    "      </li>\n",
    "      <ul class=\"toc-item\">\n",
    "         <li><span><a href=\"#install_the_model_backbone\" data-toc-modified-id=\"install_the_model_backbone\"><span class=\"toc-item-num\"></span>Install the model backbone</a></span></li>\n",
    "      </ul>\n",
    "   </ul>\n",
    "   <ul class=\"toc-item\">\n",
    "      <li>\n",
    "         <span><a href=\"#mistral_with_textclassifier_model\" data-toc-modified-id=\"mistral_with_textclassifier_model\"><span class=\"toc-item-num\"></span>Mistral with the <code>TextClassifier</code> model</a></span>\n",
    "      </li>\n",
    "      <ul class=\"toc-item\">\n",
    "         <li><span><a href=\"#import_mistral_text_classifier\" data-toc-modified-id=\"import_mistral_text_classifier\"><span class=\"toc-item-num\"></span> Import the <code>TextClassifier</code> class from the <code>arcgis.learn.text</code> module</a></span></li>\n",
    "      </ul>\n",
    "       <ul class=\"toc-item\">\n",
    "         <li><span><a href=\"#initialize_textclassifier_model_databunch\" data-toc-modified-id=\"initialize_textclassifier_model_databunch\"><span class=\"toc-item-num\"></span>Initialize the <code>TextClassifier</code> model with a databunch</a></span></li>\n",
    "      </ul>\n",
    "       <ul class=\"toc-item\">\n",
    "         <li><span><a href=\"#initialize_textclassifier_model_nodatabunch\" data-toc-modified-id=\"initialize_textclassifier_model_nodatabunch\"><span class=\"toc-item-num\"></span> Initialize the<code>TextClassifier</code> model without a databunch</a></span></li>\n",
    "      </ul>\n",
    "       <ul class=\"toc-item\">\n",
    "         <li><span><a href=\"#classify_text_using_mistral_model\" data-toc-modified-id=\"classify_text_using_mistral_model\"><span class=\"toc-item-num\"></span>Classify the text using mistral model</a></span></li>\n",
    "      </ul>\n",
    "       <ul class=\"toc-item\">\n",
    "         <li><span><a href=\"#load_model\" data-toc-modified-id=\"load_model\"><span class=\"toc-item-num\"></span>Load the model</a></span></li>\n",
    "      </ul>\n",
    "       <ul class=\"toc-item\">\n",
    "         <li><span><a href=\"#save_model\" data-toc-modified-id=\"save_model\"><span class=\"toc-item-num\"></span>Save the model</a></span></li>\n",
    "      </ul>\n",
    "   </ul>\n",
    "    <ul class=\"toc-item\">\n",
    "      <li>\n",
    "         <span><a href=\"#mistral_with_entityrecognizer_model\" data-toc-modified-id=\"mistral_with_entityrecognizer_model\"><span class=\"toc-item-num\"></span>Mistral with an<code>EntityRecognizer</code> model</a></span>\n",
    "      </li>\n",
    "      <ul class=\"toc-item\">\n",
    "         <li><span><a href=\"#import_mistral_entity_recognizer\" data-toc-modified-id=\"import_mistral_entity_recognizer\"><span class=\"toc-item-num\"></span> Import the <code>EntityRecognizer</code> class from the <code>arcgis.learn.text</code> module</a></span></li>\n",
    "      </ul>\n",
    "       <ul class=\"toc-item\">\n",
    "         <li><span><a href=\"#initialize_entity_recognizer_model_databunch\" data-toc-modified-id=\"initialize_entity_recognizer_model_databunch\"><span class=\"toc-item-num\"></span>Initialize the <code>EntityRecognizer</code> model with a databunch</a></span></li>\n",
    "      </ul>\n",
    "       <ul class=\"toc-item\">\n",
    "         <li><span><a href=\"#initialize_entity_recognizer_model_nodatabunch\" data-toc-modified-id=\"initialize_entity_recognizer_model_nodatabunch\"><span class=\"toc-item-num\"></span> Initialize the <code>EntityRecognizer</code> model without a databunch</a></span></li>\n",
    "      </ul>\n",
    "       <ul class=\"toc-item\">\n",
    "         <li><span><a href=\"#extract_entities_using_mistral_model\" data-toc-modified-id=\"extract_entities_using_mistral_model\"><span class=\"toc-item-num\"></span>Extract entities using the mistral model</a></span></li>\n",
    "      </ul>\n",
    "       <ul class=\"toc-item\">\n",
    "         <li><span><a href=\"#load_model_er\" data-toc-modified-id=\"load_model_er\"><span class=\"toc-item-num\"></span>Load the model</a></span></li>\n",
    "      </ul>\n",
    "       <ul class=\"toc-item\">\n",
    "         <li><span><a href=\"#save_model_er\" data-toc-modified-id=\"save_model_er\"><span class=\"toc-item-num\"></span>Save the model</a></span></li>\n",
    "      </ul>\n",
    "   </ul>\n",
    "    <ul class=\"toc-item\">\n",
    "      <li>\n",
    "         <span><a href=\"#conclusion\" data-toc-modified-id=\"conclusion\"><span class=\"toc-item-num\"></span>Conclusion</a></span>\n",
    "      </li>\n",
    "   </ul>\n",
    "    <ul class=\"toc-item\">\n",
    "      <li>\n",
    "         <span><a href=\"#reference\" data-toc-modified-id=\"reference\"><span class=\"toc-item-num\"></span>References</a></span>\n",
    "      </li>\n",
    "   </ul>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## Introduction to Mistral model <a id=\"about-mistral-model\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Mistral 7B is a decoder-based language model trained using almost 7 billion parameters designed to deliver both efficiency and high performance for real-world applications.\n",
    "\n",
    "Employing attention mechanisms like Sliding Window Attention, Mistral 7B can train with an 8k context length and a fixed cache size, resulting in a theoretical attention span of 128K tokens. This capability allows the model to focus on crucial parts of the text efficiently. Moreover, the model incorporates Grouped Query Attention (GQA) to accelerate inference and reduce cache size, thereby expediting its inference process. Additionally, its Byte-fallback tokenizer ensures consistent representation of characters, eliminating the need for out-of-vocabulary tokens.\n",
    "\n",
    "Such design features in its architecture equip Mistral 7B for exceptional performance, particularly in tasks related to language comprehension and generation. In this guide we see how we can use the Mistral LLM for text classification and named entity recognition."
   ]
  },
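  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The 128K figure follows from simple arithmetic. The sketch below assumes the window size (4,096 tokens) and layer count (32) reported for Mistral 7B; neither value is read from the <code>arcgis.learn</code> backbone, and <code>sliding_window_mask</code> is a hypothetical helper used only for illustration.\n",
    "\n",
    "```python\n",
    "# Values reported for Mistral 7B (assumed here, not read from arcgis.learn).\n",
    "WINDOW = 4096    # tokens each layer can look back over\n",
    "N_LAYERS = 32    # number of transformer layers\n",
    "\n",
    "def sliding_window_mask(seq_len, window):\n",
    "    '''Return a causal mask: position i may attend to positions i-window+1 .. i.'''\n",
    "    return [[0 <= i - j < window for j in range(seq_len)] for i in range(seq_len)]\n",
    "\n",
    "# Each layer extends the reach by one window, so information can propagate\n",
    "# roughly WINDOW * N_LAYERS tokens back through the stack.\n",
    "theoretical_span = WINDOW * N_LAYERS\n",
    "print(theoretical_span)  # 131072, i.e. 128K\n",
    "\n",
    "# Small mask for illustration: 'x' marks an allowed attention position.\n",
    "for row in sliding_window_mask(seq_len=6, window=3):\n",
    "    print(''.join('x' if allowed else '.' for allowed in row))\n",
    "```\n",
    "\n",
    "Each row of the printed mask contains at most three allowed positions, which is why the cache size stays fixed regardless of sequence length."
   ]
  },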
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mistral Implementation in arcgis.learn <a id=\"mistral_integration_in_arcgis_learn\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install the model backbone <a id=\"install_the_model_backbone\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Follow these steps to download and install the Mistral model backbone:\n",
    "\n",
    "1. Download the <a href=\"https://esri.maps.arcgis.com/home/item.html?id=969d2fc57c834295af1a4a42cfd51f68\">mistral model backbone</a>.\n",
    "\n",
    "2. Extract the downloaded zip file.\n",
    "\n",
    "3. Open the anaconda prompt and move to the folder that contains arcgis_mistral_backbone-1.0.0-py_0.tar.bz2\n",
    "\n",
    "4. Run:\n",
    "\n",
    " ```conda install --offline arcgis_mistral_backbone-1.0.0-py_0.tar.bz2```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mistral with the TextClassifier model <a id=\"mistral_with_textclassifier_model\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import the TextClassifier class from the arcgis.learn.text module <a id=\"import_mistral_text_classifier\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from arcgis.learn.text import TextClassifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize the TextClassifier model with a databunch <a id=\"initialize_textclassifier_model_databunch\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prepare databunch for the <code>TextClassifier</code> model using the<code>prepare_textdata</code> method in <code>arcgis.learn</code>."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from arcgis.learn import prepare_textdata\n",
    "data = prepare_textdata(\"path_to_data_folder\", task=\"classification\",train_file=\"input_csv_file.csv\",\n",
    "                        text_columns=\"text_column\", label_columns=\"label_column\")"
   ]
  },
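  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the input this call expects, the sketch below writes a small CSV file whose header matches the <code>text_columns</code> and <code>label_columns</code> arguments used above. The sample rows and the temporary folder are assumptions made for the example, not part of any real dataset.\n",
    "\n",
    "```python\n",
    "import csv\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "# Hypothetical sample rows; replace them with your own labelled text.\n",
    "rows = [\n",
    "    {'text_column': 'The staff were friendly and helpful', 'label_column': 'positive'},\n",
    "    {'text_column': 'My order arrived late and damaged', 'label_column': 'negative'},\n",
    "]\n",
    "\n",
    "data_folder = tempfile.mkdtemp()  # stands in for 'path_to_data_folder'\n",
    "csv_path = os.path.join(data_folder, 'input_csv_file.csv')\n",
    "\n",
    "# Write a header row followed by one row per labelled example.\n",
    "with open(csv_path, 'w', newline='', encoding='utf-8') as f:\n",
    "    writer = csv.DictWriter(f, fieldnames=['text_column', 'label_column'])\n",
    "    writer.writeheader()\n",
    "    writer.writerows(rows)\n",
    "\n",
    "print(csv_path)\n",
    "```\n",
    "\n",
    "A real training file would contain many more rows; two are used here only to show the layout."
   ]
  },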
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the data is prepared, the <code>TextClassifier</code> model object can be instantiated as below with the following parameters:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<code>data</code>: The databunch created using the <i>prepare_textdata</i> method.\n",
    "\n",
    "<code>backbone</code>: To use mistral as the model backbone, use <i>backbone=\"mistral\"</i>.\n",
    "\n",
    "<code>prompt</code>: Text string describing the task and its guardrails. This is an optional parameter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier_model = TextClassifier(\n",
    "    data=data,\n",
    "    backbone=\"mistral\",\n",
    "    prompt=\"Classify all the input sentences into the defined labels, do not make up your own labels.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize the TextClassifier model without a databunch <a id=\"initialize_textclassifier_model_nodatabunch\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A <code>TextClassifier</code> model with a mistral backbone can also be created without a large dataset using only a few examples."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below are the parameters to be passed into <code>TextClassifier</code>:\n",
    "\n",
    "<code>backbone</code>: To use mistral as the model backbone, use <i>backbone=\"mistral\"</i>.\n",
    "\n",
    "<code>examples</code>: User defined examples to provide the mistral model, in python dictionary format:\n",
    "```\n",
    "{\n",
    "    \"label_1\" :[\"input_text_example_1\", \"input_text_example_2\", ...],\n",
    "    \"label_2\" :[\"input_text_example_1\", \"input_text_example_2\", ...],\n",
    "    ...\n",
    "}\n",
    "```\n",
    "    \n",
    "<code>prompt</code>: Text string describing the task and its guardrails. This is an optional parameter.\n"
   ]
  },
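  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before passing an <code>examples</code> dictionary to the model, it can help to confirm that it follows the layout described above. The function below is a hypothetical helper written for this guide, not part of <code>arcgis.learn</code>, and the checks it performs are assumptions about what a well-formed dictionary looks like.\n",
    "\n",
    "```python\n",
    "def check_classifier_examples(examples):\n",
    "    '''Hypothetical helper: verify the {label: [text, ...]} layout.'''\n",
    "    if not isinstance(examples, dict) or not examples:\n",
    "        raise ValueError('examples must be a non-empty dictionary')\n",
    "    for label, texts in examples.items():\n",
    "        if not isinstance(label, str):\n",
    "            raise ValueError(f'label {label!r} is not a string')\n",
    "        if not isinstance(texts, list) or not texts:\n",
    "            raise ValueError(f'label {label!r} needs a non-empty list of texts')\n",
    "        if not all(isinstance(t, str) for t in texts):\n",
    "            raise ValueError(f'label {label!r} contains a non-string example')\n",
    "    # Report how many examples were supplied for each label.\n",
    "    return {label: len(texts) for label, texts in examples.items()}\n",
    "\n",
    "counts = check_classifier_examples({\n",
    "    'positive': ['It was a wonderful experience!'],\n",
    "    'negative': ['The customer support was unhelpful'],\n",
    "})\n",
    "print(counts)\n",
    "```\n",
    "\n",
    "The returned counts make it easy to spot a label that has far fewer examples than the others."
   ]
  },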
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier_model = TextClassifier(\n",
    "    data=None,\n",
    "    backbone=\"mistral\",\n",
    "    prompt=\"Classify all the input sentences into the defined labels, do not make up your own labels.\",\n",
    "    examples={\n",
    "                \"positive\" : [\" Good! it was a wonderful experience!\", \"i really adore your work\"],\n",
    "                \"negative\" : [\"The customer support was unhelpful\", \"I don`t like your work\"]\n",
    "            }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classify the text using mistral model <a id=\"classify_text_using_mistral_model\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To classify text using the mistral model, use the <code>predict</code> method from the <code>TextClassifier</code> class. The input to the method will be a text string or a list of text string."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier_model.predict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the model <a id=\"load_model\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To load a saved mistral model, use the <code>from_model</code> method from the <code>TextClassifier</code> class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier_model.from_model(r'path_to_dlpk_file')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save the model <a id=\"save_model\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following method saves the model weights and creates a Deep Learning Package (.dlpk)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier_model.save(\"name_of_the_mistral_model\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mistral with an EntityRecognizer model <a id=\"mistral_with_entityrecognizer_model\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import the EntityRecognizer class from the arcgis.learn.text module <a id=\"import_mistral_entity_recognizer\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from arcgis.learn.text import EntityRecognizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize the EntityRecognizer model with a databunch <a id=\"initialize_entity_recognizer_model_databunch\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prepare the databunch for the <code>EntityRecognizer</code> model using the <code>prepare_textdata</code> method in arcgis.learn."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from arcgis.learn import prepare_textdata\n",
    "data = prepare_textdata(\"path_to_data_file\", task=\"entity_recognition\", dataset_type='ner_json')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the data is prepared, the <code>EntityRecognizer</code> model object can be instantiated with the following parameters:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<code>data</code>: The databunch created using the <code>prepare_textdata</code> method.\n",
    "\n",
    "<code>backbone</code>: To use mistral as the model backbone, use <i>backbone=\"mistral\"</i>.\n",
    "\n",
    "<code>prompt</code>: Text string describing the task and its guardrails. This is an optional parameter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entity_recognizer_model = EntityRecognizer(\n",
    "    data=data,\n",
    "    backbone=\"mistral\",\n",
    "    prompt=\"Tag the input sentences in the named entity for the given classes, no other class should be tagged.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize the EntityRecognizer model without a databunch <a id=\"initialize_entity_recognizer_model_nodatabunch\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An <code>EntityRecognizer</code> model with a mistral backbone can also be created without a large dataset by using only a few examples."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below are the parameters to be passed into <code>EntityRecognizer</code> :\n",
    "\n",
    "<code>backbone</code>: To use mistral as the model backbone, use backbone=\"mistral\".\n",
    "\n",
    "<code>examples</code>: User defined examples for the mistral model, in python list format:\n",
    "\n",
    "```\n",
    "[\n",
    "    (\"input_text_sentence\", \n",
    "         {\n",
    "             \"class_1\":[\"Named Entity\", ...],\n",
    "             \"class_2\": [\"Named entity\", ...],\n",
    "             ...\n",
    "         }\n",
    "    )\n",
    "    ...\n",
    "]\n",
    "```\n",
    "\n",
    "```\n",
    "Note: The EntityRecognizer class, using the \"Mistral\" backbone, needs at least six examples to work effectively.\n",
    "```\n",
    "    \n",
    "<code>prompt</code>: Text string describing the task and its guardrails. This is an optional parameter.\n"
   ]
  },
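  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The layout and the six-example note above can be checked before the model is created. The function below is a hypothetical helper written for this guide, not part of <code>arcgis.learn</code>. The check that every entity string occurs in its sentence is an assumption about well-formed examples rather than a documented requirement.\n",
    "\n",
    "```python\n",
    "MIN_EXAMPLES = 6  # taken from the note above\n",
    "\n",
    "def check_entity_examples(examples):\n",
    "    '''Hypothetical helper: verify the [(text, {class: [entity, ...]}), ...] layout.'''\n",
    "    problems = []\n",
    "    if len(examples) < MIN_EXAMPLES:\n",
    "        problems.append(f'only {len(examples)} examples, at least {MIN_EXAMPLES} recommended')\n",
    "    for index, (text, entities) in enumerate(examples):\n",
    "        for entity_class, values in entities.items():\n",
    "            for value in values:\n",
    "                # Assumed rule: an entity should appear verbatim in its sentence.\n",
    "                if value not in text:\n",
    "                    problems.append(f'example {index}: {value!r} ({entity_class}) not found in text')\n",
    "    return problems\n",
    "\n",
    "sample = [('Jim stays in London', {'name': ['Jim'], 'location': ['London']})]\n",
    "print(check_entity_examples(sample))\n",
    "```\n",
    "\n",
    "An empty list means no problems were found; with the single sample above, the helper reports that more examples are needed."
   ]
  },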
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entity_recognizer_model = EntityRecognizer(\n",
    "    data=None,\n",
    "    backbone=\"mistral\",\n",
    "    prompt=\"Tag the input sentences in the named entity for the given classes, no other class should be tagged.\"\n",
    "    examples=[(\n",
    "            'Jim stays in London',\n",
    "            {\n",
    "                'name': ['Jim'], \n",
    "                'location': ['London']\n",
    "            },\n",
    "            ...\n",
    "        )]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract entities using the mistral model <a id=\"extract_entities_using_mistral_model\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To extract named entities using the mistral model, use the <code>extract_entities</code> method from the <code>EntityRecognizer</code> class. The input to the method will be a text string or a list of text strings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entity_recognizer_model.extract_entities()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the model <a id=\"load_model_er\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To load a saved mistral model, use the <code>from_model</code> method from the <code>EntityRecognizer</code> class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entity_recognizer_model.from_model(r'path_to_dlpk_file')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save the model <a id=\"save_model_er\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This method saves the model weights and creates a Deep Learning Package (.dlpk)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entity_recognizer_model.save(\"name_of_the_mistral_model\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion <a id=\"conclusion\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this guide we demonstrated the steps to initialize and perform inference using the Mistral LLM as a backbone with the TextClassifier and EntityRecognizer models in arcgis.learn."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References <a id=\"reference\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Mistral-7B HuggingFace: <a href=\"https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2\">https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2</a><br/>\n",
    "Mistral-7B MistralAI: <a href=\"https://mistral.ai/news/announcing-mistral-7b/\">https://mistral.ai/news/announcing-mistral-7b</a>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
